{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "gv3ILgZGywAc"
   },
   "source": [
    "# **Overview**\n",
    "This notebook benchmarks the MONAI's implementation of global mutual information ANTsPyx's implementation."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "emx2G8VNyfVg"
   },
   "source": [
    "# **Global Mutual Information**\n",
    "Mutual information is an entropy-based measure of image alignment derived from probabilistic measures of image intensity\n",
    "values. Because a large number of image samples are used to estimate image statistics, the effects of image noise on the\n",
    "metric are attenuated. Mutual information is also robust under varying amounts of image overlap as the test image moves\n",
    "with respect to the reference. [1]\n",
    "\n",
    "Formally, the mutual information between two images `A` and `B` is defined as the following\n",
    "\n",
    "<img src=https://latex.codecogs.com/svg.image?I(a%2Cb)%26space%3B%3D%26space%3B%5Csum_%7Ba%2Cb%7D%26space%3Bp(a%2Cb)%26space%3B%5Clog(%5Cfrac%7Bp(a%2Cb)%7D%7Bp(a)p(b)%7D)>\n",
    "\n",
    "where `a` and `b` respectively refers to intensity bin centers of `A` and `B`.\n",
    "\n",
    "We used Parzen windowing in our implementation - given a set of `n` samples in image `A`, each sample `x` contributes to \n",
    "`p(a)` with a function of its intensity and the bin centre `a`:\n",
    "\n",
    "<img src=https://latex.codecogs.com/svg.image?p(a)%3D%26space%3B%5Cfrac%7B1%7D%7Bn%7D%26space%3B%5Csum_%7Bx%26space%3B%5Cin%26space%3BA%7D%26space%3BW(x%2C%26space%3Ba)>\n",
    "\n",
    "Similarly:\n",
    "\n",
    "<img src=https://latex.codecogs.com/svg.image?p(b)%3D%26space%3B%5Cfrac%7B1%7D%7Bn%7D%26space%3B%5Csum_%7By%26space%3B%5Cin%26space%3BB%7D%26space%3BW(y%2C%26space%3Bb)>\n",
    "\n",
    "To compute the joint distribution, we treat each sample as a pair of intensities of corresponding locations in the two images:\n",
    "\n",
    "<img src=https://latex.codecogs.com/svg.image?p(a%2Cb)%26space%3B%3D%26space%3B%5Cfrac%7B1%7D%7Bn%7D%5Csum_%7B(x%2Cy)%5Cin(A%2CB)%7D%26space%3BW(x%2Ca)W(y%2Cb)%26space%3B>\n",
    "\n",
    "\n",
    "Two weighting functions - ``\"gaussian\"`` and ``\"b-spline\"`` - are provided. \n",
    "Here, we compare our ``\"b-spline\"`` method with the validated [ANTsPy](https://antspy.readthedocs.io/en/latest/) \n",
    "library.\n",
    "\n",
    ">[1] \"PET-CT Image Registration in the Chest Using Free-form Deformations\"\n",
    "D. Mattes, D. R. Haynor, H. Vesselle, T. Lewellen and W. Eubank\n",
    "IEEE Transactions in Medical Imaging. Vol.22, No.1,\n",
    "January 2003. pp.120-128. "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "a3hamU0vz5QY"
   },
   "source": [
    "# **Setup enviornment**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "0_v9XIO0z4jq",
    "outputId": "27e4eabb-0e7a-4d37-92e3-39f098ae1b56"
   },
   "outputs": [],
   "source": [
    "!python -c \"import monai\" || pip install -q \"monai-weekly[nibabel]\"\n",
    "!python -c \"import ants\" || pip install -q antspyx==0.29\n",
    "!python -c \"import plotly\" || pip install -q plotly==5.3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "spcc50wTqYkZ"
   },
   "outputs": [],
   "source": [
    "import ants\n",
    "import os\n",
    "import tempfile\n",
    "import torch\n",
    "import plotly.graph_objects as go\n",
    "import numpy as np\n",
    "from monai import transforms\n",
    "from monai.apps.utils import download_url\n",
    "from monai.losses import GlobalMutualInformationLoss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "xL0nYrF-q0du",
    "outputId": "5a53430e-14f7-492a-8c56-4ddb01cf0516"
   },
   "outputs": [],
   "source": [
    "# Copyright 2020 MONAI Consortium\n",
    "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
    "# you may not use this file except in compliance with the License.\n",
    "# You may obtain a copy of the License at\n",
    "#     http://www.apache.org/licenses/LICENSE-2.0\n",
    "# Unless required by applicable law or agreed to in writing, software\n",
    "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
    "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
    "# See the License for the specific language governing permissions and\n",
    "# limitations under the License.\n",
    "from monai.config import print_config\n",
    "\n",
    "print_config()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "1YIQidqErCZ0"
   },
   "source": [
    "# **Download data**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "GTgU9PMbrm6z",
    "outputId": "2da2d797-d1e4-4553-e9d6-a36f010cd1c6"
   },
   "outputs": [],
   "source": [
    "directory = os.environ.get(\"MONAI_DATA_DIRECTORY\")\n",
    "root_dir = tempfile.mkdtemp() if directory is None else directory\n",
    "print(f\"root dir is: {root_dir}\")\n",
    "file_url = \"https://drive.google.com/uc?id=17tsDLvG_GZm7a4fCVMCv-KyDx0hqq1ji\"\n",
    "file_path = f\"{root_dir}/Prostate_T2W_AX_1.nii\"\n",
    "download_url(file_url, file_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "_rz8fAg_uIqa"
   },
   "source": [
    "# **Comparison**\n",
    "Both ANTsPy's and our implementation follows [1] - a third order BSpline kernel is used for the pred image intensity PDF\n",
    "and a zero order (box car) BSpline kernel is used for the target image intensity PDF.\n",
    "\n",
    "For benchmarking, we set the number of bins to 32, same as\n",
    "[ANTsPy implementation](https://github.com/ANTsX/ANTsPy/blob/master/ants/lib/LOCAL_antsImageMutualInformation.cxx).\n",
    "\n",
    "We took a lower-pelvic 3d MRI as `a1` and transformed it to get `a2` and report the \n",
    "Global Mutual Information between `a1` and `a2` derived with ANTsPy's and our implementation. "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "81Sm_mAWtL46"
   },
   "source": [
    "Here, we first initialise a few functions necessary for comparison"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "xwq56csXq2RF"
   },
   "outputs": [],
   "source": [
    "def transformation(\n",
    "        translate_params=(0., 0., 0.),\n",
    "        rotate_params=(0., 0., 0.),\n",
    "):\n",
    "    \"\"\"\n",
    "    Read and transform Prostate_T2W_AX_1.nii\n",
    "    Args:\n",
    "        translate_params: a tuple of 3 floats, translation is in pixel/voxel relative to the center of the input image.\n",
    "                Defaults to no translation.\n",
    "        rotate_params: a rotation angle in radians, a tuple of 3 floats for 3D.\n",
    "                Defaults to no rotation.\n",
    "    Returns:\n",
    "        numpy array of shape HWD\n",
    "    \"\"\"\n",
    "    transform_list = [\n",
    "        transforms.LoadImaged(keys=\"img\"),\n",
    "        transforms.Affined(\n",
    "            keys=\"img\",\n",
    "            translate_params=translate_params,\n",
    "            rotate_params=rotate_params,\n",
    "            as_tensor_output=False,\n",
    "            device=None,\n",
    "        ),\n",
    "        transforms.NormalizeIntensityd(keys=[\"img\"])\n",
    "    ]\n",
    "    transformation = transforms.Compose(transform_list)\n",
    "    return transformation({\"img\": file_path})[\"img\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "LKGg1lALuJDY"
   },
   "outputs": [],
   "source": [
    "def get_result(a1, a2):\n",
    "    \"\"\"\n",
    "    Calculate mutual information with both ANTsPyx and MONAI implementation\n",
    "    Args:\n",
    "        a1: numpy array of shape HWD\n",
    "        a2: numpy array of shape HWD\n",
    "    \"\"\"\n",
    "    antspyx_result = ants.image_mutual_information(\n",
    "        ants.from_numpy(a1),\n",
    "        ants.from_numpy(a2)\n",
    "    )\n",
    "    monai_result = GlobalMutualInformationLoss(\n",
    "        kernel_type=\"b-spline\",\n",
    "        num_bins=32,\n",
    "        sigma_ratio=0.015\n",
    "    )(\n",
    "        torch.tensor(a1).unsqueeze(0).unsqueeze(0),\n",
    "        torch.tensor(a2).unsqueeze(0).unsqueeze(0)\n",
    "    ).item()\n",
    "    return antspyx_result, monai_result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "DOTkgv5Yw18t"
   },
   "outputs": [],
   "source": [
    "def plot(x, results, xaxis_title):\n",
    "    \"\"\"\n",
    "    Plot diagram to compare ANTsPyx and MONAI result\n",
    "    Args:\n",
    "        x: list, x_axis values\n",
    "        results: list of list\n",
    "        xaxis_title: str\n",
    "    \"\"\"\n",
    "    data = [\n",
    "        go.Scatter(\n",
    "            x=x,\n",
    "            y=y,\n",
    "            name=n,\n",
    "            mode=\"lines+markers\",\n",
    "            line={'color': color, 'width': 1},\n",
    "        )\n",
    "        for y, n, color in zip(results, ['ANTsPy', 'MONAI'], ['coral', 'cornflowerblue'])\n",
    "    ]\n",
    "    fig = go.Figure(data=data)\n",
    "    fig.update_layout(\n",
    "        xaxis_title=xaxis_title,\n",
    "        yaxis_title='MutualInformation',\n",
    "        yaxis_range=[-2.0, 0.0]\n",
    "    )\n",
    "    fig.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "k9zmXZP1vKZi"
   },
   "outputs": [],
   "source": [
    "def compare_antspyx_monai(transform_params_list, transform_name):\n",
    "    \"\"\"\n",
    "    Args:\n",
    "        transform_params_list: a list of tuple\n",
    "        transform_name: str\n",
    "    \"\"\"\n",
    "    antspyx_result = []\n",
    "    monai_result = []\n",
    "    # a1 is the original image without translation and rotation\n",
    "    a1 = transformation((0., 0., 0.))\n",
    "\n",
    "    for transform_params in transform_params_list:\n",
    "        # translate/rotate the image to get a2\n",
    "        a2 = transformation(\n",
    "            translate_params=transform_params[0],\n",
    "            rotate_params=transform_params[1]\n",
    "        )\n",
    "        a_r, m_r = get_result(a1, a2)\n",
    "        antspyx_result.append(a_r)\n",
    "        monai_result.append(m_r)\n",
    "\n",
    "    # calculate the transformation euclidean_distance\n",
    "    x = [np.linalg.norm(np.array(translation_param)) for translation_param in transform_params_list]\n",
    "    # sort results by the transformation euclidean distance\n",
    "    antspyx_result = [i for _, i in sorted(zip(x, antspyx_result))]\n",
    "    monai_result = [i for _, i in sorted(zip(x, monai_result))]\n",
    "    x = sorted(x)\n",
    "    plot(\n",
    "        x=x,\n",
    "        results=[antspyx_result, monai_result],\n",
    "        xaxis_title=transform_name,\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "kdsFsUIF7ew8"
   },
   "source": [
    "The following image visualises the 3d MRI after transformed by different translation params: \n",
    "\n",
    "![a](https://i.ibb.co/6X03szZ/translation-vis.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "L2sd-lUbz-_E"
   },
   "source": [
    "**Translation**\n",
    "\n",
    "First, we incrementally increase the translation in all (x, y, z) directions by (1.0, 1.0, 1.0)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 542
    },
    "id": "at1PbfE6z_cX",
    "outputId": "d3c3ca0a-9619-4892-d301-a7742b5575d8"
   },
   "outputs": [],
   "source": [
    "transform_params_list = [((i, i, i), (0., 0., 0.))for i in range(10)]\n",
    "compare_antspyx_monai(transform_params_list, \"xyz_translation\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "_95Yp0fe40M1"
   },
   "source": [
    "Then, we translate in single directions by randomly sampled parameters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 542
    },
    "id": "O9Zh8O_u5OE9",
    "outputId": "e2f2eca9-0d51-4238-d95c-af7303070d9e"
   },
   "outputs": [],
   "source": [
    "transform_params_list = [((np.random.rand() * 10, 0., 0.), (0., 0., 0.))for i in range(10)]\n",
    "compare_antspyx_monai(transform_params_list, \"x_translation\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 542
    },
    "id": "D5UIJ1xN5bYT",
    "outputId": "407585fd-da33-4fd1-df53-74a358439ea6"
   },
   "outputs": [],
   "source": [
    "transform_params_list = [((0., np.random.rand() * 10, 0.), (0., 0., 0.))for i in range(10)]\n",
    "compare_antspyx_monai(transform_params_list, \"y_translation\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 542
    },
    "id": "WtrtaPIL5bkQ",
    "outputId": "6a496a2e-c4fa-4f51-c22c-041bde26d726"
   },
   "outputs": [],
   "source": [
    "transform_params_list = [((0., 0., np.random.rand() * 10), (0., 0., 0.))for i in range(10)]\n",
    "compare_antspyx_monai(transform_params_list, \"z_translation\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "JrU3BBcP5qqN"
   },
   "source": [
    "**Rotation**\n",
    "\n",
    "We also incrementally increase the rotation in all (x, y, z) directions by (1.0, 1.0, 1.0)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 542
    },
    "id": "J8UWPiZ75-T8",
    "outputId": "5bc77d1f-c2bd-4a2a-d58c-d44730ce00cd"
   },
   "outputs": [],
   "source": [
    "transform_params_list = [((0., 0., 0.), (np.pi / 100 * i, np.pi / 100 * i, np.pi / 100 * i))for i in range(10)]\n",
    "compare_antspyx_monai(transform_params_list, \"rotation\")"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "name": "Untitled0.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
